import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
import time
import scipy.stats as st
import matplotlib.pyplot as plt
import subprocess
from sklearn.preprocessing import MinMaxScaler
from scipy.interpolate import griddata
import shutil

# Sizes shared by HF_model:
#   times          - number of pressure output blocks kept per member
#   fracturepoints - values per pressure snapshot; also the size of the
#                    flattened 50 x 50 KDE grid
#   tracerpoints   - number of breakthrough samples kept per member
#   NPar           - ensemble size assumed by the 10 x 5 permeability plot grid
times, fracturepoints, tracerpoints, NPar = 12, 2500, 61, 50

def create_folder(folder_path):
    """Recreate ``folder_path`` as an empty directory.

    If the path already exists, its whole tree is deleted first (and that is
    logged).  The directory is then created, together with any missing
    parent directories.
    """
    target = folder_path
    stale = os.path.exists(target)
    if stale:
        shutil.rmtree(target)
        print("delete old files: {}".format(target))

    os.makedirs(target)
    print("creat new files: {}".format(target))


# Token that marks, in the simulator input templates (FRACTURE1.inp, MESH1),
# the lines where HF_model substitutes a formatted numeric value.
_PLACEHOLDER = "aaaaaaaaaa"


def _fill_placeholders(template_path, output_path, replacements):
    """Copy ``template_path`` to ``output_path``, filling placeholders in order.

    Every line containing the placeholder token consumes the next string of
    ``replacements``; all occurrences on that line receive that same value.

    Raises:
        ValueError: if the template does not contain exactly
            ``len(replacements)`` placeholder lines.  (The former inline loop
            spun forever when there were too few and raised IndexError when
            there were too many.)
    """
    used = 0
    with open(template_path, 'r') as src, open(output_path, 'w') as dst:
        for line in src:
            if _PLACEHOLDER in line:
                if used >= len(replacements):
                    raise ValueError(
                        f"{template_path}: more placeholder lines than the "
                        f"{len(replacements)} values supplied")
                line = line.replace(_PLACEHOLDER, replacements[used])
                used += 1
            dst.write(line)
    if used != len(replacements):
        raise ValueError(
            f"{template_path}: found {used} placeholder lines, expected "
            f"{len(replacements)}")


def _run_batch(command):
    """Run ``command`` through the shell, wait for it and return its stdout.

    ``subprocess.run`` drains the pipe via ``communicate()``, which avoids the
    deadlock ``Popen.wait()`` can hit when the child fills the stdout pipe,
    and it closes the child's stdin.  The exit status is ignored, as before.
    """
    completed = subprocess.run(command, bufsize=2048, shell=True,
                               stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                               close_fds=True)
    return completed.stdout


def _plot_permeability_fields(mesh, permeability, k):
    """Save one contour plot per ensemble member (column of ``permeability``)."""
    mesh_x, mesh_y = mesh[:, 0], mesh[:, 1]
    # Coarse 11 x 11 interpolation grid over [0, 500] x [0, 500].
    grid_x = np.linspace(0, 500, 11)
    grid_y = np.linspace(0, 500, 11)
    grid_xx, grid_yy = np.meshgrid(grid_x, grid_y)

    fig, axs = plt.subplots(10, 5, figsize=(24, 36))
    axs = axs.flatten()
    contour = None
    # Guard against ensembles smaller than NPar (previously an IndexError).
    for p in range(min(NPar, permeability.shape[1], axs.size)):
        z = griddata((mesh_x, mesh_y), permeability[:, p], (grid_xx, grid_yy), method='linear')
        z = np.nan_to_num(z)  # linear griddata returns NaN outside the convex hull
        contour = axs[p].contourf(grid_x, grid_y, z, levels=100, alpha=0.5, cmap='jet')
        axs[p].set_title(f'Plot {p+1}')
        axs[p].set_aspect('equal')
    if contour is not None:
        plt.colorbar(contour, label='permeability')
    plt.tight_layout()
    plt.savefig(r'./Mengdi/seismicity/k={}/permeability_plots.png'.format(k), dpi=300)
    plt.close(fig)


def _read_pressure_blocks(path):
    """Parse OUTPUT_DATA_RES.dat into a 2-D float32 array, one row per block.

    Lines starting with 'VARI' or 'ZONE' separate blocks; every other line
    contributes its fifth whitespace-separated field.  Empty blocks are dropped.
    """
    with open(path, 'r') as fh:
        lines = fh.readlines()
    blocks, current = [], []
    for line in lines:
        if line[:4] in ('VARI', 'ZONE'):
            blocks.append(current)
            current = []
        else:
            current.append(np.array(line.split()[4], dtype=np.float32))
    blocks.append(current)
    return np.array([block for block in blocks if block])


def _seismicity_kde(pressures, mesh, threshold, day_times, member_id, k):
    """Turn one member's pressure snapshots into a flattened microseismicity KDE.

    A cell is flagged at a snapshot when its value is >= its entry in
    ``threshold``.  Each flagged location is stamped with the earliest time
    it was flagged, a scatter plot is saved, and a Gaussian KDE of the flagged
    (x, y) positions is evaluated on a 50 x 50 grid over [0, 500]^2.  Returns
    ``fracturepoints`` zeros when no usable events exist.
    """
    n_steps, n_cells = pressures.shape
    flagged = (pressures.astype(np.float64) >= threshold[:n_cells]).astype(np.float64)

    xs = np.tile(mesh[:, 0], n_steps).astype('float64')
    ys = np.tile(mesh[:, 1], n_steps).astype('float64')
    stamps = np.repeat(day_times[:n_steps], n_cells).astype('float64')
    events = np.vstack((xs, ys, flagged.reshape(-1), stamps)).T

    # NOTE(review): the original code skipped the first TWO rows here (a [1:]
    # slice followed by a loop starting at index 1).  Preserved for parity
    # with previously generated results; confirm whether this is intended.
    candidates = events[2:]
    events = candidates[candidates[:, 2] != 0]

    # Give every event the earliest time recorded at its (x, y) location.
    earliest = {}
    for x_val, y_val, _, t_val in events:
        key = (x_val, y_val)
        if key not in earliest or t_val < earliest[key]:
            earliest[key] = t_val
    for row in events:
        row[3] = min(row[3], earliest.get((row[0], row[1]), row[3]))

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111)
    scatter = ax.scatter(events[:, 0], events[:, 1], c=events[:, 3], cmap='coolwarm', s=50, zorder=20)
    plt.colorbar(scatter, label='miceoseismic{}'.format(member_id))
    ax.set_xlim(0, 500)
    ax.set_ylim(0, 500)
    plt.savefig(r'./Mengdi/seismicity/k={}/scatter{}-times{}.png'.format(k, member_id, n_steps))
    plt.close(fig)

    grid_x, grid_y = np.mgrid[0:500:50j, 0:500:50j]
    positions = np.vstack([grid_x.ravel(), grid_y.ravel()])
    locations = np.vstack([events[:, 0], events[:, 1]])
    if locations.shape[1] == 0:
        return np.zeros(fracturepoints)
    try:
        kernel = st.gaussian_kde(locations)  # Gaussian kernel density function
        density = np.reshape(kernel(positions).T, grid_x.shape)
    except (np.linalg.LinAlgError, ValueError) as err:
        # Too few / collinear events give a singular covariance; previously
        # this aborted the whole iteration.
        print(f"KDE skipped for member {member_id}: {err}")
        return np.zeros(fracturepoints)
    return density.reshape(fracturepoints)


def _read_permeability(path):
    """Return (before, after) permeability strings from a PERMEAB.OUT file.

    ``after`` is the third column of the ``fracturepoints`` lines preceding the
    final line; ``before`` is the third column of 0-based lines
    ``fracturepoints + 3`` up to (excluding) ``2 * fracturepoints + 3``.
    """
    n = fracturepoints
    with open(path, mode='r', encoding='utf-8') as fh:
        lines = fh.readlines()
    after = [line.split()[2] for line in lines[-(n + 1):-1]]
    before = [line.split()[2] for line in lines[n + 3:2 * n + 3]]
    return before, after


def _plot_permeability_change(mesh, before, after, member_id, file_index, k):
    """Save before/after permeability columns and a before/after/ratio figure.

    Text files are numbered by ``file_index`` while the figure is named by
    ``member_id`` (kept as in the original pipeline).
    """
    perm_after = np.array(after, dtype=float).reshape(-1, 1)
    np.savetxt(r'./Mengdi/tracer/k={}/perm_after_biot{}.txt'.format(k, file_index), perm_after, fmt='%.4E')
    perm_before = np.array(before, dtype=float).reshape(-1, 1)
    np.savetxt(r'./Mengdi/tracer/k={}/perm_before_biot{}.txt'.format(k, file_index), perm_before, fmt='%.4E')

    # Zero initial permeability yields NaN; silence the division warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = np.where(perm_before != 0, perm_after / perm_before, np.nan)
    ratio = ratio.reshape(-1, 1)

    fig, axs = plt.subplots(3, 1, figsize=(8, 18))
    scatter1 = axs[0].scatter(mesh[:, 0], mesh[:, 1], c=perm_before, cmap='viridis', s=50)
    axs[0].set_title('Former permeability')
    plt.colorbar(scatter1, ax=axs[0])

    scatter2 = axs[1].scatter(mesh[:, 0], mesh[:, 1], c=perm_after, cmap='viridis', s=50)
    axs[1].set_title('Later permeability')
    plt.colorbar(scatter2, ax=axs[1])

    scatter3 = axs[2].scatter(mesh[:, 0], mesh[:, 1], c=ratio, cmap='coolwarm', zorder=20, s=50)
    axs[2].set_title('Ratio')
    plt.colorbar(scatter3, ax=axs[2], label='ratio{}'.format(member_id))

    plt.tight_layout()
    plt.savefig(r'./Mengdi/tracer/k={}/ratio{}.png'.format(k, member_id))
    plt.close(fig)


def _read_breakthrough(path):
    """Return up to ``tracerpoints`` concentration values from a Ttim.dat file.

    Column 1 is scaled by 365.25 and truncated to int.  The maximum 'con' is
    kept per resulting time value, and missing integer times are filled by
    linear interpolation.
    """
    data = pd.read_table(path, header=4, sep=r'\s+')
    # Round trip through f.csv kept from the original pipeline for parity
    # (it also leaves f.csv in the working directory).
    data.to_csv('f.csv', index=False, header=True)
    data = pd.read_csv('./f.csv', engine="python")
    data.iloc[:, 1] = (data.iloc[:, 1] * 365.25).astype(int)
    data = data.rename(columns={'Time(yr)': 't', 't_26nds': 'con'})[['t', 'con']]
    groups = data.groupby(["t"]).max()
    full_index = pd.Index(range(int(groups.index.min()), int(groups.index.max()) + 1))
    table = groups.reindex(full_index).interpolate(method='linear').reset_index().to_numpy()
    conc = table[table[:, 0] % 1 == 0, 1]
    return conc[:tracerpoints]


def HF_model(modelname, main_dir, x, idf, k):
    """Run one ensemble iteration of the Biot and tracer simulations.

    Steps: rescale and plot the parameter fields, write them into each
    member's ``FRACTURE.inp`` and run the Biot batch job, convert the
    resulting pressures into a microseismicity KDE per member, copy the
    updated permeabilities into each tracer ``MESH`` and run the tracer batch
    jobs, then collect and plot the tracer breakthrough curves.

    Args:
        modelname: Not used here; kept for callers such as ``extract``.
        main_dir: Directory holding the ``HF_in_chain<id>`` and
            ``tracerTest<id>`` workspaces of every member.
        x: 2-D array with one row per ensemble member.  Each row is min-max
            rescaled to [-6, -4] and used as log10 of the values written
            into ``FRACTURE.inp``.
        idf: Member identifiers, indexed in step with the rows of ``x``.
        k: Iteration label used in output folder and file names.

    Returns:
        ``(fy1, y)``: ``fy1`` has shape ``(n_members, fracturepoints)`` (KDE
        values); ``y`` has shape ``(n_members, tracerpoints)`` (breakthrough
        concentrations).

    Side effects: recreates ``./Mengdi/seismicity/k=<k>`` and
    ``./Mengdi/tracer/k=<k>``, writes input files into the member
    workspaces, runs several ``.bat`` scripts in the current directory and
    writes ``f.csv`` there.
    """
    # Rescale each member (a column of x.T) independently to [-6, -4].
    scaler = MinMaxScaler(feature_range=(-6, -4)).fit(x.T)
    x1 = scaler.transform(x.T)      # (n_cells, n_members)
    create_folder('./Mengdi/seismicity/k={}'.format(k))
    x = x1.T                        # (n_members, n_cells)
    n_members = x.shape[0]

    mesh = np.array(pd.read_csv(r'mesh.csv', engine="python"))
    _plot_permeability_fields(mesh, pow(10, x1), k)

    save_x = np.concatenate((mesh[:, :2], x1), axis=1)
    np.savetxt(r'./Mengdi/seismicity/k={}/x2500.txt'.format(k), save_x, fmt='%.3f')

    print('Biot_run...')
    for j in range(n_members):
        workspace = main_dir + "/HF_in_chain" + str(idf[j])
        # Each value is written 10 times in a row (same order as the former
        # np.dstack of ten copies followed by reshape(-1)).
        values = np.repeat(pow(10, x[j, :]), 10)
        _fill_placeholders(workspace + "/FRACTURE1.inp", workspace + "/FRACTURE.inp",
                           ["{0:.4E}".format(v) for v in values])

    print('toghbiot for chains')
    _run_batch("Toughbiot1.bat")
    print('done')
    subprocess.call("dataext.bat")

    # Per-cell thresholds (presumably critical pressures - confirm) and the
    # time stamp of each output snapshot; both are shared by all members.
    threshold = np.array(np.loadtxt('linjiefenbu38.txt'))
    day_times = np.loadtxt('Day_chen.txt', delimiter=' ', dtype=str, skiprows=0).astype(float)

    fy1 = np.zeros([n_members, fracturepoints])
    for jjj in range(n_members):
        workspace = main_dir + "/HF_in_chain" + str(idf[jjj])
        blocks = _read_pressure_blocks(workspace + "/OUTPUT_DATA_RES.dat")
        pressures = blocks[:times, :].reshape(-1, fracturepoints)
        fy1[jjj, :] = _seismicity_kde(pressures, mesh, threshold, day_times, idf[jjj], k)

    save_weizhen = np.concatenate((mesh[:, :2], fy1.T), axis=1)
    np.savetxt(r'./Mengdi/seismicity/k={}/weizhen_KDE.txt'.format(k), save_weizhen, fmt='%.4E')

    create_folder('./Mengdi/tracer/k={}'.format(k))
    print('React_run...')
    for i in range(n_members):
        workspace = main_dir + "/HF_in_chain" + str(idf[i])
        # Was workspace + './PERMEAB.OUT', i.e. ".../HF_in_chainN./PERMEAB.OUT",
        # which only resolved on Windows (trailing dots are stripped there).
        before, after = _read_permeability(workspace + '/PERMEAB.OUT')
        _plot_permeability_change(mesh, before, after, idf[i], i + 1, k)

        tracer_dir = main_dir + "/tracerTest" + str(idf[i])
        _fill_placeholders(tracer_dir + '/MESH1', tracer_dir + '/MESH',
                           ["{0:.4E}".format(float(v) * 1e15) for v in after])

    print('inject_tracer...')
    subprocess.run('Begin_tracer_chen.bat', shell=True)
    _run_batch("TracerRun1.bat")
    print('done...')

    print('displacement...')
    subprocess.run('restart_water_chen.bat', shell=True)
    _run_batch("TracerRun1.bat")
    print('done...')

    # Previously sized by NPar; now by the number of simulated members so
    # it matches fy1 (identical when n_members == NPar).
    y = np.zeros([n_members, tracerpoints])
    for i in range(n_members):
        curve = _read_breakthrough(main_dir + "/tracerTest" + str(idf[i]) + '/Ttim.dat')
        y[i, :len(curve)] = curve

    hours = np.arange(1, tracerpoints + 1)
    fig, ax = plt.subplots()
    for curve in y:
        ax.plot(hours, curve, color='#9a9fa2')
    observed = np.loadtxt('./tracer38.txt')
    ax.scatter(hours, observed, color='#e42313', zorder=2)
    ax.set_xlabel('Time (h)')
    ax.set_ylabel('C (mg/L)')
    ax.set_title(f'Inversion 1 - iteration {k} Breakthrough curves of IES in sythetic case ')
    fig.savefig(r'./Mengdi/tracer/Breakthrough{}.png'.format(k), dpi=600)
    plt.close(fig)

    return fy1, y


def extract(workspace, X, idf, k):
    """Prepare ``workspace`` and run ``HF_model`` for one ensemble iteration.

    Stale output files named after the model ('EGS') are removed first on a
    best-effort basis.  Each file is attempted independently, and OS-level
    failures (missing file, permission error) are ignored.

    Args:
        workspace: Directory passed to ``HF_model`` as ``main_dir``; created
            if it does not exist.
        X: Parameter array; cast to float before use.
        idf: Member identifiers forwarded to ``HF_model``.
        k: Iteration label forwarded to ``HF_model``.

    Returns:
        The ``(fy1, y)`` pair returned by ``HF_model``.
    """
    os.makedirs(workspace, exist_ok=True)
    modelname = "EGS"
    # Previously one bare try/except wrapped all three removals, so a
    # missing first file meant the other two were never attempted (and
    # even KeyboardInterrupt was swallowed).
    for stale in ('MT3D001.UCN', modelname + '.hds', modelname + '.cbc'):
        try:
            os.remove(os.path.join(workspace, stale))
        except OSError:
            pass
    X = X.astype('float')
    y1, y = HF_model(modelname, workspace, X, idf, k)
    return y1, y



